/*
 * Copyright (c) 2023, Peter Abeles. All Rights Reserved.
 *
 * This file is part of Efficient Java Matrix Library (EJML).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.ejml.dense.row.mult;

import org.ejml.CodeGeneratorBase;

import java.io.FileNotFoundException;


/**
 * <p>
 * This class generates code for various matrix-matrix multiplication operations. The code associated
 * with these operations is often only slightly different from one variant to the next, so to remove
 * some of the tedium of writing and maintaining it by hand, it is autogenerated.
 * </p>
 * <p>
 * To create MatrixMatrixMult_DDRM simply run this application and copy it to the appropriate location.
 * </p>
 *
 * @author Peter Abeles
 */
public class GenerateMatrixMatrixMult_DDRM extends CodeGeneratorBase {

    /**
     * Writes the complete MatrixMatrixMult_DDRM source: the imports and class Javadoc, every
     * multiplication variant for each combination of the {@code alpha} and {@code add} flags,
     * and finally the closing brace of the class. The output stream is closed even if
     * generation fails part way through.
     *
     * @throws FileNotFoundException propagated from {@code setOutputFile}
     */
    @Override
    public void generate() throws FileNotFoundException {
        setOutputFile("MatrixMatrixMult_DDRM");
        try {
            String preamble =
                    "import org.ejml.MatrixDimensionException;\n" +
                    "import org.ejml.data.DMatrix1Row;\n" +
                    "import org.ejml.UtilEjml;\n" +
                    "import org.ejml.dense.row.CommonOps_DDRM;\n" +
                    "import org.jetbrains.annotations.Nullable;\n" +
                    "//CONCURRENT_INLINE import org.ejml.concurrency.EjmlConcurrency;\n" +
                    "\n" +
                    "/**\n" +
                    " * <p>\n" +
                    " * This class contains various types of matrix matrix multiplication operations for {@link DMatrix1Row}.\n" +
                    " * </p>\n" +
                    " * <p>\n" +
                    " * Two algorithms that are equivalent can often have very different runtime performance.\n" +
                    " * This is because of how modern computers uses fast memory caches to speed up reading/writing to data.\n" +
                    " * Depending on the order in which variables are processed different algorithms can run much faster than others,\n" +
                    " * even if the number of operations is the same.\n" +
                    " * </p>\n" +
                    " *\n" +
                    " * <p>\n" +
                    " * Algorithms that are labeled as 'reorder' are designed to avoid caching jumping issues, some times at the cost\n" +
                    " * of increasing the number of operations. This is important for large matrices. The straight forward\n" +
                    " * implementation seems to be faster for small matrices.\n" +
                    " * </p>\n" +
                    " *\n" +
                    " * <p>\n" +
                    " * Algorithms that are labeled as 'aux' use an auxiliary array of length n. This array is used to create\n" +
                    " * a copy of an out of sequence column vector that is referenced several times. This reduces the number\n" +
                    " * of cache misses. If the 'aux' parameter passed in is null then the array is declared internally.\n" +
                    " * </p>\n" +
                    " *\n" +
                    " * <p>\n" +
                    " * Typically the straight forward implementation runs about 30% faster on smaller matrices and\n" +
                    " * about 5 times slower on larger matrices. This is all computer architecture and matrix shape/size specific.\n" +
                    " * </p>\n" +
                    standardClassDocClosing("Peter Abeles") +
                    "public class "+className+" {\n";

            out.print(preamble);

            // i selects the alpha flag, j selects the add flag: all four combinations are generated
            for (int i = 0; i < 2; i++) {
                boolean alpha = i == 1;
                for (int j = 0; j < 2; j++) {
                    boolean add = j == 1;
                    printMult_reorder(alpha, add);
                    out.print("\n");
                    printMult_small(alpha, add);
                    out.print("\n");
                    printMult_aux(alpha, add);
                    out.print("\n");
                    printMultTransA_reorder(alpha, add);
                    out.print("\n");
                    printMultTransA_small(alpha, add);
                    out.print("\n");
                    printMultTransAB(alpha, add);
                    out.print("\n");
                    printMultTransAB_aux(alpha, add);
                    out.print("\n");
                    printMultTransB(alpha, add);
                    // don't print a separator after the very last method to avoid an extra blank line at the end
                    if (i != 1 || j != 1)
                        out.print("\n");
                }
            }
            out.println("}");
        } finally {
            out.close();
        }
    }

    /**
     * Returns the factor placed in front of a scaled term in generated code.
     *
     * @param alpha true if the generated method takes a scalar {@code alpha}
     * @return {@code "alpha*"} when {@code alpha} is true, otherwise an empty string
     */
    private static String alphaFactor( boolean alpha ) {
        return alpha ? "alpha*" : "";
    }

    /**
     * Returns the name of the {@code DMatrix1Row} method used to write a result element into C.
     *
     * @param add true if the generated method accumulates into C
     * @return {@code "plus"} when {@code add} is true, otherwise {@code "set"}
     */
    private static String writeOp( boolean add ) {
        return add ? "plus" : "set";
    }

    /**
     * Builds the argument-validation prologue of a generated method. It asserts that C is neither
     * A nor B, asserts that the inner dimensions of op(A) and op(B) agree, and then either reshapes
     * C to the output size or asserts that C already has it. When requested, it also emits a lazy
     * allocation of the {@code aux} work array.
     *
     * @param tranA true if A appears transposed in the product
     * @param tranB true if B appears transposed in the product
     * @param auxLength Java expression for the length of {@code aux}, or null if the method has no aux array
     * @param reshape true to emit {@code C.reshape(...)}, false to emit a shape assertion on C
     * @return the generated code, ending with a blank line
     */
    private String makeBoundsCheck(boolean tranA, boolean tranB, String auxLength, boolean reshape)
    {
        String a_numCols = tranA ? "A.numRows" : "A.numCols";
        String a_numRows = tranA ? "A.numCols" : "A.numRows";
        String b_numCols = tranB ? "B.numRows" : "B.numCols";
        String b_numRows = tranB ? "B.numCols" : "B.numRows";

        String ret =
                "        UtilEjml.assertTrue(A != C && B != C, \"Neither 'A' or 'B' can be the same matrix as 'C'\");\n" +
                "        UtilEjml.assertShape("+a_numCols+", "+b_numRows+", \"The 'A' and 'B' matrices do not have compatible dimensions\");\n";
        if( reshape)
            ret += "        C.reshape("+a_numRows+", "+b_numCols+");\n";
        else {
            ret += "        UtilEjml.assertShape(" + a_numRows + " == C.numRows && " + b_numCols +
                    " == C.numCols, \"C is not compatible with A and B\");\n";
        }

        ret += "\n";

        if( auxLength != null ) {
            ret += "        if (aux == null) aux = new double["+auxLength+"];\n\n";
        }

        return ret;
    }

    /**
     * Builds an early-return guard for when A has zero rows or zero columns. When overwriting
     * (not {@code add}) the generated code first fills C with zeros; when adding, C is left untouched.
     * Needed by the "reorder" and "aux" transposed variants, which would otherwise read past the
     * end of B (or A) when a dimension is zero.
     *
     * @param add true if the generated method accumulates into C
     * @return the generated guard code
     */
    private String handleZeros( boolean add ) {

        String fill = add ? "" : "            CommonOps_DDRM.fill(C, 0);\n";

        String ret =
                "        if (A.numCols == 0 || A.numRows == 0) {\n" +
                fill +
                "            return;\n" +
                "        }\n";
        return ret;
    }

    /**
     * Builds the Javadoc of a generated method. It links to the {@code CommonOps_DDRM} method
     * of the same name whose arguments are an optional {@code double} followed by three
     * {@code DMatrix1Row}s.
     *
     * @param nameOp full name of the operation, e.g. {@code multAddTransA}
     * @param hasAlpha true if the linked method takes a leading {@code double alpha}
     * @return the generated Javadoc block
     */
    private String makeComment( String nameOp , boolean hasAlpha )
    {
        String a = hasAlpha ? "double, " : "";
        String inputs = "("+a+"org.ejml.data.DMatrix1Row, org.ejml.data.DMatrix1Row, org.ejml.data.DMatrix1Row)";

        String ret =
                "    /**\n" +
                "     * @see CommonOps_DDRM#"+nameOp+inputs+"\n" +
                "     */\n";
        return ret;
    }

    /**
     * Builds the Javadoc and signature line of a generated method. The method name is
     * {@code nameOp}, then "Add" if {@code add}, then "TransAB", "TransA" or "TransB" according
     * to the transpose flags, and finally "_" + {@code variant} when a variant is given.
     *
     * @param nameOp base operation name, e.g. {@code mult}
     * @param variant implementation suffix such as {@code reorder}, or null for none
     * @param add true for the accumulate ("Add") form
     * @param hasAlpha true if the method takes a leading {@code double alpha}
     * @param hasAux true if the method takes a trailing nullable {@code double[] aux}
     * @param tranA true if A is transposed
     * @param tranB true if B is transposed
     * @return the generated Javadoc and method signature, ending with the opening brace
     */
    private String makeHeader(String nameOp, String variant,
                              boolean add, boolean hasAlpha, boolean hasAux,
                              boolean tranA, boolean tranB)
    {
        if( add ) nameOp += "Add";

        // make the op name
        if( tranA && tranB ) {
            nameOp += "TransAB";
        } else if( tranA ) {
            nameOp += "TransA";
        } else if( tranB ) {
            nameOp += "TransB";
        }

        String ret = makeComment(nameOp,hasAlpha)+
                     "    public static void "+nameOp;

        if( variant != null ) ret += "_"+variant+"( ";
        else ret += "( ";

        if( hasAlpha ) ret += "double alpha, ";

        if( hasAux ) {
            ret += "DMatrix1Row A, DMatrix1Row B, DMatrix1Row C, @Nullable double[] aux ) {\n";
        } else {
            ret += "DMatrix1Row A, DMatrix1Row B, DMatrix1Row C ) {\n";
        }

        return ret;
    }

    /**
     * Prints the "reorder" variant of C = alpha*A*B. The scaling by alpha is present only when
     * {@code alpha} is true, and the result is added to C (C += ...) when {@code add} is true.
     * For each row i of C, the row is first written from row 0 of B using set/plus. The
     * remaining rows of B are then accumulated with {@code +=}, so B is read sequentially.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMult_reorder( boolean alpha , boolean add ) {
        String header = makeHeader("mult","reorder",add,alpha, false, false,false);
        String valLine = "valA = " + alphaFactor(alpha) + "A.data[indexA++];\n";
        String assignment = writeOp(add);

        String code =
                header + makeBoundsCheck(false,false, null,!add)+handleZeros(add) +
                        "        final int endOfKLoop = B.numRows*B.numCols;\n"+
                        "\n" +
                        "        //CONCURRENT_BELOW EjmlConcurrency.loopFor(0, A.numRows, i -> {\n" +
                        "        for (int i = 0; i < A.numRows; i++) {\n" +
                        "            int indexCbase = i*C.numCols;\n" +
                        "            int indexA = i*A.numCols;\n" +
                        "\n"+
                        "            // need to assign C.data to a value initially\n" +
                        "            int indexB = 0;\n" +
                        "            int indexC = indexCbase;\n" +
                        "            int end = indexB + B.numCols;\n" +
                        "\n" +
                        "            double "+valLine +
                        "\n" +
                        "            while (indexB < end) {\n" +
                        "                C."+assignment+"(indexC++, valA*B.data[indexB++]);\n" +
                        "            }\n" +
                        "\n" +
                        "            // now add to it\n"+
                        "            while (indexB != endOfKLoop) { // k loop\n"+
                        "                indexC = indexCbase;\n" +
                        "                end = indexB + B.numCols;\n" +
                        "\n" +
                        "                "+valLine+
                        "\n" +
                        "                while (indexB < end) { // j loop\n" +
                        "                    C.data[indexC++] += valA*B.data[indexB++];\n" +
                        "                }\n" +
                        "            }\n" +
                        "        }\n" +
                        "        //CONCURRENT_ABOVE });\n" +
                        "    }\n";

        out.print(code);
    }

    /**
     * Misspelled alias of {@link #printMult_reorder(boolean, boolean)}, kept so existing callers
     * keep compiling.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     * @deprecated use {@link #printMult_reorder(boolean, boolean)} instead
     */
    @Deprecated
    public void printMult_reroder( boolean alpha , boolean add ) {
        printMult_reorder(alpha, add);
    }

    /**
     * Prints the "small" (straightforward) variant of C = alpha*A*B. Each element of C is
     * the dot product of a row of A with a column of B.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMult_small( boolean alpha , boolean add ) {
        String header = makeHeader("mult","small",add,alpha, false, false,false);
        String valLine = "                C." + writeOp(add) + "(cIndex++, " + alphaFactor(alpha) + "total);\n";

        String code =
                header + makeBoundsCheck(false,false, null,!add)+
                        "        //CONCURRENT_BELOW EjmlConcurrency.loopFor(0, A.numRows, i -> {\n" +
                        "        for (int i = 0; i < A.numRows; i++) {\n" +
                        "            int cIndex = i*B.numCols;\n" +
                        "            int aIndexStart = i*A.numCols;\n" +
                        "            for (int j = 0; j < B.numCols; j++) {\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                int indexA = aIndexStart;\n" +
                        "                int indexB = j;\n" +
                        "                int end = indexA + B.numRows;\n" +
                        "                while (indexA < end) {\n" +
                        "                    total += A.data[indexA++]*B.data[indexB];\n" +
                        "                    indexB += B.numCols;\n" +
                        "                }\n" +
                        "\n" +
                        valLine +
                        "            }\n" +
                        "        }\n" +
                        "        //CONCURRENT_ABOVE });\n" +
                        "    }\n";
        out.print(code);
    }

    /**
     * Prints the "aux" variant of C = alpha*A*B. Each column of B is copied into the {@code aux}
     * array (length B.numRows) before its dot products with every row of A are computed. The
     * method is wrapped in CONCURRENT_OMIT markers, presumably so that no concurrent version of
     * it is generated (TODO confirm against the concurrency generator).
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMult_aux( boolean alpha , boolean add ) {
        String header = makeHeader("mult","aux",add,alpha, true, false,false);
        String valLine = "                C." + writeOp(add) + "(i*C.numCols + j, " + alphaFactor(alpha) + "total);\n";

        String code =
                header + makeBoundsCheck(false,false, "B.numRows",!add)+
                        "        for (int j = 0; j < B.numCols; j++) {\n" +
                        "            // create a copy of the column in B to avoid cache issues\n" +
                        "            for (int k = 0; k < B.numRows; k++) {\n" +
                        "                aux[k] = B.unsafe_get(k, j);\n" +
                        "            }\n" +
                        "\n" +
                        "            int indexA = 0;\n" +
                        "            for (int i = 0; i < A.numRows; i++) {\n" +
                        "                double total = 0;\n" +
                        "                for (int k = 0; k < B.numRows; ) {\n" +
                        "                    total += A.data[indexA++]*aux[k++];\n" +
                        "                }\n" +
                        valLine +
                        "            }\n" +
                        "        }\n" +
                        "    }\n";
        out.println("    //CONCURRENT_OMIT_BEGIN");
        out.print(code);
        out.println("    //CONCURRENT_OMIT_END");
    }

    /**
     * Prints the "reorder" variant of C = alpha*A<sup>T</sup>*B. For each column i of A, row i
     * of C is first written from row 0 of B using set/plus. The remaining rows of B are then
     * accumulated with {@code +=}.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMultTransA_reorder( boolean alpha , boolean add ) {
        String header = makeHeader("mult","reorder",add,alpha, false, true,false);
        String assignment = writeOp(add);
        String valLine1 = "valA = " + alphaFactor(alpha) + "A.data[i];\n";
        String valLine2 = "valA = " + alphaFactor(alpha) + "A.unsafe_get(k, i);\n";

        String code =
                header + makeBoundsCheck(true,false, null,!add)+handleZeros(add)+
                        "        //CONCURRENT_BELOW EjmlConcurrency.loopFor(0, A.numCols, i -> {\n" +
                        "        for (int i = 0; i < A.numCols; i++) {\n" +
                        "            int indexC_start = i*C.numCols;\n" +
                        "\n" +
                        "            // first assign R\n" +
                        "            double " +valLine1+
                        "            int indexB = 0;\n" +
                        "            int end = indexB + B.numCols;\n" +
                        "            int indexC = indexC_start;\n" +
                        "            while (indexB < end) {\n" +
                        "                C."+assignment+"(indexC++, valA*B.data[indexB++]);\n" +
                        "            }\n" +
                        "            // now increment it\n" +
                        "            for (int k = 1; k < A.numRows; k++) {\n" +
                        "                " +valLine2+
                        "                end = indexB + B.numCols;\n" +
                        "                indexC = indexC_start;\n" +
                        "                // this is the loop for j\n" +
                        "                while (indexB < end) {\n" +
                        "                    C.data[indexC++] += valA*B.data[indexB++];\n" +
                        "                }\n" +
                        "            }\n" +
                        "        }\n" +
                        "        //CONCURRENT_ABOVE });\n" +
                        "    }\n";
        out.print(code);
    }

    /**
     * Prints the "small" variant of C = alpha*A<sup>T</sup>*B. Each element of C is the dot
     * product of a column of A with a column of B.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMultTransA_small( boolean alpha , boolean add ) {
        String header = makeHeader("mult","small",add,alpha, false, true,false);
        String valLine = "C." + writeOp(add) + "(cIndex++, " + alphaFactor(alpha) + "total);\n";

        String code =
                header + makeBoundsCheck(true,false, null,!add)+
                        "        //CONCURRENT_BELOW EjmlConcurrency.loopFor(0, A.numCols, i -> {\n" +
                        "        for (int i = 0; i < A.numCols; i++) {\n" +
                        "            int cIndex = i*B.numCols;\n" +
                        "            for (int j = 0; j < B.numCols; j++) {\n" +
                        "                int indexA = i;\n" +
                        "                int indexB = j;\n" +
                        "                int end = indexB + B.numRows*B.numCols;\n" +
                        "\n" +
                        "                double total = 0;\n" +
                        "                // loop for k\n" +
                        "                for (; indexB < end; indexB += B.numCols) {\n" +
                        "                    total += A.data[indexA]*B.data[indexB];\n" +
                        "                    indexA += A.numCols;\n" +
                        "                }\n" +
                        "\n" +
                        "                "+valLine +
                        "            }\n" +
                        "        }\n" +
                        "        //CONCURRENT_ABOVE });\n" +
                        "    }\n";

        out.print(code);
    }

    /**
     * Prints C = alpha*A*B<sup>T</sup>. Each element of C is the dot product of a row of A
     * with a row of B, so both inputs are read along contiguous rows.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMultTransB( boolean alpha , boolean add ) {
        String header = makeHeader("mult",null,add,alpha, false, false,true);
        String valLine = "C." + writeOp(add) + "(cIndex++, " + alphaFactor(alpha) + "total);\n";

        String code =
                header + makeBoundsCheck(false,true, null,!add)+
                        "        //CONCURRENT_BELOW EjmlConcurrency.loopFor(0, A.numRows, xA -> {\n" +
                        "        for (int xA = 0; xA < A.numRows; xA++) {\n" +
                        "            int cIndex = xA*B.numRows;\n" +
                        "            int aIndexStart = xA*B.numCols;\n" +
                        "            int end = aIndexStart + B.numCols;\n" +
                        "            int indexB = 0;\n"+
                        "            for (int xB = 0; xB < B.numRows; xB++) {\n" +
                        "                int indexA = aIndexStart;\n" +
                        "\n" +
                        "                double total = 0;\n" +
                        "                while (indexA < end) {\n" +
                        "                    total += A.data[indexA++]*B.data[indexB++];\n" +
                        "                }\n" +
                        "\n" +
                        "                "+valLine +
                        "            }\n" +
                        "        }\n" +
                        "        //CONCURRENT_ABOVE });\n" +
                        "    }\n";
        out.print(code);
    }

    /**
     * Prints C = alpha*A<sup>T</sup>*B<sup>T</sup>. Each element of C is the dot product of a
     * column of A with a row of B.
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMultTransAB( boolean alpha , boolean add ) {
        String header = makeHeader("mult",null,add,alpha, false, true,true);
        String valLine = "C." + writeOp(add) + "(cIndex++, " + alphaFactor(alpha) + "total);\n";

        String code =
                header + makeBoundsCheck(true,true, null,!add)+
                        "        //CONCURRENT_BELOW EjmlConcurrency.loopFor(0, A.numCols, i -> {\n" +
                        "        for (int i = 0; i < A.numCols; i++) {\n" +
                        "            int cIndex = i*B.numRows;\n" +
                        "            int indexB = 0;\n"+
                        "            for (int j = 0; j < B.numRows; j++) {\n" +
                        "                int indexA = i;\n" +
                        "                int end = indexB + B.numCols;\n" +
                        "\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                while (indexB < end) {\n" +
                        "                    total += A.data[indexA]*B.data[indexB++];\n" +
                        "                    indexA += A.numCols;\n" +
                        "                }\n" +
                        "\n" +
                        "                "+valLine+
                        "            }\n" +
                        "        }\n"+
                        "        //CONCURRENT_ABOVE });\n" +
                        "    }\n";
        out.print(code);
    }

    /**
     * Prints the "aux" variant of C = alpha*A<sup>T</sup>*B<sup>T</sup>. Each column of A is
     * copied into the {@code aux} array (length A.numRows) before its dot products with every
     * row of B are computed. The method is wrapped in CONCURRENT_OMIT markers, presumably so
     * that no concurrent version of it is generated (TODO confirm against the concurrency generator).
     *
     * @param alpha true to generate the variant that takes a scalar {@code alpha}
     * @param add true to generate the variant that adds to C instead of overwriting it
     */
    public void printMultTransAB_aux( boolean alpha , boolean add ) {
        String header = makeHeader("mult","aux",add,alpha, true, true,true);
        String valLine = "C." + writeOp(add) + "(indexC++, " + alphaFactor(alpha) + "total);\n";

        String code =
                header + makeBoundsCheck(true,true, "A.numRows",!add)+handleZeros(add)+
                        "        int indexC = 0;\n" +
                        "        for (int i = 0; i < A.numCols; i++) {\n" +
                        "            for (int k = 0; k < B.numCols; k++) {\n" +
                        "                aux[k] = A.unsafe_get(k, i);\n" +
                        "            }\n" +
                        "\n" +
                        "            for (int j = 0; j < B.numRows; j++) {\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                for (int k = 0; k < B.numCols; k++) {\n" +
                        "                    total += aux[k]*B.unsafe_get(j, k);\n" +
                        "                }\n" +
                        "                "+valLine +
                        "            }\n" +
                        "        }\n"+
                        "    }\n";
        out.println("    //CONCURRENT_OMIT_BEGIN");
        out.print(code);
        out.println("    //CONCURRENT_OMIT_END");
    }

    /**
     * Entry point: regenerates MatrixMatrixMult_DDRM.
     *
     * @param args ignored
     * @throws FileNotFoundException propagated from {@link #generate()}
     */
    public static void main(String[] args) throws FileNotFoundException {
        new GenerateMatrixMatrixMult_DDRM().generate();
    }
}
